Learned Routing + Two-Pass N-gram Rescoring + Extended Orders (2-12)#860
Open
pappanick wants to merge 3 commits into openai:main from
Conversation
…ders

Combines PR openai#834's learned multi-expert routing head with PR openai#846's two-pass cold-cache rescoring.

Key changes:
- Extended n-gram orders from 2-7 to 2-12, with 8M-bucket hash tables
- Two-pass eval: rescore the first 15 chunks with the full cache after pass 1
- Per-chunk loss tracking for precise pass-1/pass-2 delta computation
- Configurable via env vars: NGRAM_MAX_ORDER, NGRAM_BUCKETS, TWO_PASS_ENABLED, TWO_PASS_RESCORE_CHUNKS

Based on the PR openai#834 (AnirudhRahul) + PR openai#846 (himanshudongre) stack.
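The bucketed n-gram tables behind these env vars can be sketched in miniature as follows. This is an illustrative stand-in, not the PR's code: `NgramCache`, `ngram_bucket`, and the hash constants are invented for the sketch; only the `NGRAM_MAX_ORDER` / `NGRAM_BUCKETS` variable names come from the PR.

```python
import os
from collections import defaultdict

# Env-var names mirror the PR; defaults match its stated config (orders 2-12,
# 8M buckets). Everything else here is a hypothetical sketch.
NGRAM_MAX_ORDER = int(os.environ.get("NGRAM_MAX_ORDER", "12"))
NGRAM_BUCKETS = int(os.environ.get("NGRAM_BUCKETS", str(8 * 1024 * 1024)))

def ngram_bucket(context, order, num_buckets=NGRAM_BUCKETS):
    """Hash the last (order-1) tokens of context into a fixed bucket index."""
    h = order  # salt by order so identical contexts map differently per order
    for tok in context[-(order - 1):]:
        h = (h * 1000003 + tok) % (1 << 61)
    return h % num_buckets

class NgramCache:
    """Per-order bucketed next-token counts, updated causally during eval."""
    def __init__(self, orders=range(2, NGRAM_MAX_ORDER + 1)):
        self.orders = list(orders)
        self.counts = {o: defaultdict(lambda: defaultdict(int)) for o in self.orders}

    def update(self, context, next_token):
        for o in self.orders:
            if len(context) >= o - 1:
                self.counts[o][ngram_bucket(context, o)][next_token] += 1

    def predict(self, context, order):
        """Most frequent continuation for this bucket, or None if unseen."""
        tbl = self.counts[order]
        key = ngram_bucket(context, order)
        if key not in tbl:
            return None
        dist = tbl[key]
        return max(dist, key=dist.get)
```

Fixed-size buckets keep memory bounded regardless of how many distinct n-grams appear, at the cost of hash collisions; larger `NGRAM_BUCKETS` trades memory for fewer collisions.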
- Per-head learned gate in attention (PR openai#638/openai#733): -0.002 BPB
- Lambda_v * x0 shortcut from the initial embedding (PR openai#657/openai#733): -0.002 BPB
- Both enabled by default via GATED_ATTENTION=1, VALUE_RESIDUAL=1
- Added attn_gate and lambda_v to the control-tensor patterns for proper quantization handling
- All smoke tests pass on CPU
…eader

Major additions:
- Depth recurrence: layers 4 and 5 repeated -> 13 virtual layers from 11 physical
  - Repeat blocks share heavy CastedLinear weights but own their scalar params
  - untie_recurrence() deep-copies the shared weights before TTT for independent specialization
  - Only ~1% parameter overhead during training
- TTT defaults changed to match PR openai#733's winning recipe:
  - SGD optimizer (was AdamW): simpler, less memory
  - lr=0.002 (was 0.0005): higher to suit SGD
  - Unfreeze all 11 blocks (was 2): more params for adaptation
  - All repeat_blocks params unfrozen for TTT
- Configurable via: RECUR_LAYERS="4,5" TTT_OPTIMIZER=sgd TTT_LR=0.002

All smoke tests pass on CPU (syntax, recurrence, weight sharing, untie).
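The share-then-untie mechanism in this commit can be sketched with plain-Python stand-ins (no torch; `Block`, `build_virtual_stack`, and the list-valued `weight` are hypothetical, only `untie_recurrence` and the 11-physical/13-virtual layout come from the PR):

```python
import copy

class Block:
    """Stand-in for a transformer block: one heavy weight + owned scalars."""
    def __init__(self, weight):
        self.weight = weight      # heavy shared tensor (here: a plain list)
        self.attn_scale = 1.0     # per-virtual-layer scalar, never shared

def build_virtual_stack(physical, recur_layers=(4, 5)):
    """Append repeat blocks whose weights alias the named physical layers."""
    stack = list(physical)
    for idx in recur_layers:
        # Same weight object -> near-zero extra params during training.
        stack.append(Block(physical[idx].weight))
    return stack

def untie_recurrence(stack, num_physical):
    """Deep-copy shared weights before TTT so repeat layers can specialize."""
    for block in stack[num_physical:]:
        block.weight = copy.deepcopy(block.weight)
```

During training the repeat blocks alias the physical weights (object identity, so gradients would flow to one storage); after `untie_recurrence` each repeat block owns an equal-valued but independent copy.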
Summary
Combines techniques from PRs #834, #846, #733, and #693 into a single submission with 9 innovations.
Techniques
Learned routing head: Linear(512, 12).

Architecture
PR #834/414 stack: 11 physical layers (13 virtual via depth recurrence), 512d, 8H, 8KV, LeakyReLU(0.5)^2, U-Net skips, SmearGate, BigramHash(6144), Partial RoPE (16/64), XSA all layers, VE128 on layers 9-10, EMA+SWA, GPTQ int5 + zstd-22.
Key Innovation: Depth Recurrence without Bank Refactor
Instead of PR #733's parameter-bank approach, we use shared module references: repeat blocks share CastedLinear weights with their physical blocks but own independent scalar params (attn_scale, mlp_scale, attn_gate, lambda_v). Before TTT, untie_recurrence() deep-copies the heavy weights so repeat layers can specialize independently. ~1% parameter overhead during training, full independence during TTT.

Two-Pass Rescoring
Pass 1: Standard sequential chunk eval with causal n-gram cache building.
Pass 2: Rescore first 15 chunks with the full cache (no updates). Early chunks improve dramatically since their n-gram experts now have full context. Per-chunk loss tracking enables precise delta computation.
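The two-pass loop described above can be written schematically as follows. `score_chunk` and `update_cache` are hypothetical callables standing in for the real model-plus-cache machinery; only the rescore-first-15-chunks behavior and the `TWO_PASS_RESCORE_CHUNKS` env var come from the PR.

```python
import os

# Default matches the PR's TWO_PASS_RESCORE_CHUNKS=15 setting.
RESCORE_CHUNKS = int(os.environ.get("TWO_PASS_RESCORE_CHUNKS", "15"))

def two_pass_eval(chunks, score_chunk, update_cache, rescore=RESCORE_CHUNKS):
    """Pass 1: score sequentially while building the cache causally.
    Pass 2: rescore the earliest chunks with the (frozen) full cache."""
    pass1 = []
    for chunk in chunks:
        pass1.append(score_chunk(chunk))  # loss under the partial cache
        update_cache(chunk)               # causal update after scoring
    pass2 = list(pass1)
    for i in range(min(rescore, len(chunks))):
        pass2[i] = score_chunk(chunks[i])  # full cache, no updates in pass 2
    deltas = [a - b for a, b in zip(pass1, pass2)]  # per-chunk improvement
    return pass2, deltas
```

Because the cache is only read in pass 2, rescoring stays causal-free of leakage within pass 1 while giving early chunks the same cache coverage late chunks already enjoyed; the per-chunk deltas quantify exactly how much each early chunk gained.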
Status
Run Command
RECUR_LAYERS="4,5" GATED_ATTENTION=1 VALUE_RESIDUAL=1 \
TWO_PASS_ENABLED=1 TWO_PASS_RESCORE_CHUNKS=15 \
NGRAM_MAX_ORDER=12 NGRAM_BUCKETS=8388608 \
TTT_OPTIMIZER=sgd TTT_LR=0.002 TTT_FREEZE_BLOCKS=11 \
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=6144 XSA_LAST_N=11 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

Credits